import hashlib
import json
import argparse
import os
from typing import List
import numpy as np
import time

def hash_data(data: str) -> str:
    """Return the lowercase hex SHA-256 digest of ``data`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(data.encode("utf-8"))
    return hasher.hexdigest()

def construct_flat_merkle_tree(leaves: List[str]) -> List[str]:
    """Build a complete binary Merkle tree and return it as a flat array.

    The leaf list is padded up to the next power of two by repeating its last
    element. Each padded leaf is hashed with ``hash_data``, and every internal
    node is ``hash_data(left_child_hex + right_child_hex)``.

    Layout of the returned list (length ``2 * width - 1``, where ``width`` is
    the padded leaf count):
      * index 0 is the root;
      * the children of node ``i`` are at ``2*i + 1`` and ``2*i + 2``;
      * the leaf hashes occupy the last ``width`` slots, in input order.

    NOTE(review): leaf and internal-node hashes are not domain-separated, and
    duplicate-last-leaf padding makes distinct inputs collide. For example,
    ``[a, b, c]`` and ``[a, b, c, c]`` yield the same root. Changing this would
    alter every root, so it is left as-is.

    Raises:
        ValueError: if ``leaves`` is empty.
    """
    count = len(leaves)
    if count == 0:
        raise ValueError("Empty leaves")

    # Smallest power of two >= count (1 when count == 1).
    width = 1 << (count - 1).bit_length()
    padded = leaves + [leaves[-1]] * (width - count)

    # Internal slots first (filled below), then the hashed leaves.
    nodes = [None] * (width - 1) + [hash_data(item) for item in padded]

    # Walk internal nodes bottom-up so both children exist before the parent.
    for idx in reversed(range(width - 1)):
        nodes[idx] = hash_data(nodes[2 * idx + 1] + nodes[2 * idx + 2])

    return nodes

def save_tree(tree: List[str], output_file: str):
    """Write ``tree`` to ``output_file`` as UTF-8 JSON indented by two spaces.

    An existing file at ``output_file`` is overwritten.
    """
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(tree, indent=2))

def load_tree(input_file: str) -> List[str]:
    """Load a flat Merkle tree from a JSON file, such as one written by ``save_tree``.

    The file must contain a JSON array of strings laid out like the output of
    ``construct_flat_merkle_tree``. That is a complete binary tree stored as an
    array, so its length is ``2**k - 1`` for some ``k >= 1``.

    Args:
        input_file: Path to the JSON file (read as UTF-8).

    Returns:
        The node hashes in flat array order (root at index 0).

    Raises:
        ValueError: if the JSON is not a list of strings, or its length is
            not of the form ``2**k - 1``.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        tree = json.load(f)

    # Validate up front: otherwise a malformed file would be returned silently
    # as something that does not match the List[str] contract.
    if not isinstance(tree, list) or not all(isinstance(node, str) for node in tree):
        raise ValueError("Tree file must contain a list of strings.")

    size = len(tree)
    # A complete tree has size + 1 equal to a power of two. The bit test alone
    # would accept size == 0 (1 & 0 == 0), so reject the empty list explicitly.
    if size == 0 or (size + 1) & size:
        raise ValueError(
            f"Tree file has {size} nodes; expected 2**k - 1 for a complete tree."
        )
    return tree

def print_merkle_tree_flat(tree: List[str]):
    """Print a flat Merkle tree level by level, starting from the root.

    Level widths are derived from the leaf count ``(len(tree) + 1) // 2``.
    That count is halved repeatedly (integer division) down to 1, and the
    widths are then printed in reverse order: 1, 2, 4, ...  Each level is
    printed as a header, one indented hash per line, and a blank line.
    An empty tree prints nothing.
    """
    widths = []
    width = (len(tree) + 1) // 2
    while width:
        widths.append(width)
        width //= 2
    widths.reverse()  # root level first

    offset = 0
    for level, count in enumerate(widths):
        print(f"Level {level}:")
        for node_hash in tree[offset:offset + count]:
            print(f"  {node_hash}")
        print()
        offset += count

def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Merkle Tree Tool")
    parser.add_argument("--input_json", type=str, help="Input JSON file with list of strings")
    parser.add_argument("--input_npy", type=str, help="Input NPY file with float32 vectors")
    parser.add_argument("--save_tree", type=str, help="Path to save Merkle tree JSON")
    parser.add_argument("--load_tree", type=str, help="Path to load Merkle tree JSON")
    parser.add_argument(
        "--dim",
        type=int,
        default=768,
        help="Required row width of the --input_npy array (default: 768)",
    )
    return parser.parse_args()


def _read_json_leaves(path: str) -> List[str]:
    """Read leaves from a JSON file that must contain a list of strings."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list) or not all(isinstance(x, str) for x in data):
        raise ValueError("JSON file must contain a list of strings.")
    return data


def _read_npy_leaves(path: str, dim: int = 768) -> List[str]:
    """Read a 2-D ``.npy`` array of width ``dim`` and turn each row into a string leaf."""
    vectors = np.load(path)
    if len(vectors.shape) != 2 or vectors.shape[1] != dim:
        raise ValueError(f"NPY file must contain [n, {dim}] float32 vectors.")
    # tolist() converts elements to Python floats, so str() emits the repr of
    # the widened value. This exact text is what gets hashed, so any change
    # here changes every root.
    return [",".join(map(str, vec.tolist())) for vec in vectors]


def main():
    """Command-line entry point: load a saved tree, or build one from input data.

    Only one source is used, chosen in this order: ``--load_tree``, then
    ``--input_json``, then ``--input_npy``. Any other source flags are ignored.
    A loaded tree is only validated by ``load_tree``; it is not printed or
    re-saved. A built tree is timed and, if ``--save_tree`` is given, written
    to that path.

    Raises:
        ValueError: if no input source is given, or the input fails validation.
    """
    args = _parse_args()

    if args.load_tree:
        load_tree(args.load_tree)
        print("[✓] Loaded Merkle tree from file.")
        return

    if args.input_json:
        data = _read_json_leaves(args.input_json)
    elif args.input_npy:
        data = _read_npy_leaves(args.input_npy, args.dim)
    else:
        raise ValueError("Please provide either --input_json, --input_npy or --load_tree.")

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump when it is adjusted, skewing the measured duration.
    start_time = time.perf_counter()
    tree = construct_flat_merkle_tree(data)
    elapsed = time.perf_counter() - start_time
    print(f"[✓] Constructed Merkle tree in {elapsed:.2f} seconds.")

    if args.save_tree:
        save_tree(tree, args.save_tree)
        print(f"[✓] Tree saved to {args.save_tree}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
